

% Start from a clean workspace; only the MAP parameter vector is kept.
clearvars -except MAP


%% data (including out of sample)
% The .mat file presumably supplies Sigma and the prior hyper-parameters
% (a0, om_a, th0, ...) used further down -- confirm against the file.
load('../../true_DGP/priorWSigma.mat')
data=readtable('../../../Data/data_main_sample.csv');

% Keep rows from 1950 onward with a finite D; the period grid itself
% starts strictly after 1950.
data=data(data.year>=1950 & isfinite(data.D),:);
ccodes=unique(data.ccode);
years=unique(data.year(data.year>1950));
N=numel(ccodes);
T=numel(years);
K=0; % no time-varying covariates in this specification

% N-by-T panels: y, D and the observation mask M (1 = observed);
% X is N-by-T-by-K and therefore empty when K=0.
y=zeros(N,T);
D=y;
M=y;
X=repmat(y,1,1,K);
for n=1:N
    rows=data(data.ccode==ccodes(n),:);
    for t=1:T
        hit=(rows.year==years(t));
        % a country-year is filled only when the year exists for this
        % country and its y value is finite
        if any(hit)
            if isfinite(rows.y(hit))
                M(n,t)=1;
                y(n,t)=rows.y(hit);
                D(n,t)=rows.D(hit);
            end
        end
    end
end
clear rows hit t

% Capital-to-capital distance matrix; the first CSV column is skipped
% (assumed to be an identifier column -- confirm), values scaled by 1e-6.
dd=readtable('../../../Data/distance_capitals.csv');
Z=table2array(dd(:,2:end));
Z=10^(-6)*Z;
KZ=size(Z,3); % Z is a 2-D array here, so this is 1

% Standard deviations taken from the diagonal of Sigma.
ep_sd=sqrt(diag(Sigma));

% Sparse-grid integration nodes/weights (external nwspgr, 'KPN' rule,
% dimension 2, accuracy level 8).
[sgi_nodes,sgi_weights]=nwspgr('KPN',2,8);

% Attach each country's acronym (first unique idacr value) to its code.
cnames=cell(N,1);
for n=1:N
    a=unique(data.idacr(data.ccode==ccodes(n)));
    cnames{n}=a{1};
end
ccodes=table(ccodes,cnames);

% World Regions (by ccode):
% America: 2-165; Europe: 200-255,290,305,310-317,325,339-370,375-390;
% Africa: 404-625,651;
% Asia & Oceania: 371-373,630-645,652-705,710-740,750-771,775,780-790,800-920
% Regions are checked in this order; the first match wins.
region_labels={'AME','EUR','AFR','ASO'};
region_codes={2:165, ...
    [200:255,290,305,310:317,325,339:370,375:390], ...
    [404:625,651], ...
    [371:373,630:645,652:705,710:740,750:771,775,780:790,800:920]};

% Fallback: an unclassified country keeps its acronym as its "region",
% which adds a singleton dummy to R (and so changes Kth and the MAP
% layout below). Warn so this cannot happen silently.
ccodes.region=ccodes.cnames;
for n=1:size(ccodes,1)
    code=ccodes.ccodes(n);
    r=find(cellfun(@(c) ismember(code,c),region_codes),1);
    if isempty(r)
        warning('regions:unclassified', ...
            'ccode %d (%s) matches no region; it becomes its own region dummy.', ...
            code,ccodes.cnames{n});
    else
        ccodes.region(n)=region_labels(r);
    end
end

clear a cnames data dd n code r region_labels region_codes

R=dummyvar(categorical(ccodes.region)); % regions (one column per category)
Kth=size(R,2);


%% parameter estimates
% MAP is unpacked sequentially as
% [alpha (N); theta (2*Kth, reshaped Kth-by-2); beta (2N); v (N); f (N);
%  vs (N); xi (K); gamma (KZ)], tracking the position with a running offset.
off=0;
alpha=MAP(off+(1:N),1);                      off=off+N;
theta=reshape(MAP(off+(1:2*Kth),1),Kth,2);   off=off+2*Kth;
beta=MAP(off+(1:2*N),1);                     off=off+2*N;
v=MAP(off+(1:N),1);                          off=off+N;
f=MAP(off+(1:N),1);                          off=off+N;
vs=MAP(off+(1:N),1);                         off=off+N;
xi=MAP(off+(1:K),1);                         off=off+K;
gamma=MAP(off+(1:KZ),1);
clear off


%% log-likelihood
% log_posterior_thetas (external) is evaluated on the first 50 periods and
% negated -- presumably it returns the negative log-posterior; confirm.
% Subtracting the log-prior kernels below then leaves the log-likelihood.
in_s=1:50;
nlp=log_posterior_thetas(MAP,D(:,in_s),M(:,in_s),X(:,in_s,:),Z,R,a0,om_a,th0,om_th,b0A,b0D,om_b,s_v,d_v,f0,om_f,s_vs,d_vs);

% Gaussian kernels for alpha, theta, beta (two halves) and f;
% inverse-gamma-type kernels for v and vs (normalising constants omitted).
lp_alpha=sum(-0.5*((alpha-a0)/om_a).^2);
lp_theta=sum(-0.5*((reshape(theta,2*Kth,1)-th0)/om_th).^2);
lp_betaA=sum(-0.5*((beta(1:N)-b0A)/om_b).^2);
lp_betaD=sum(-0.5*((beta(N+(1:N))-b0D)/om_b).^2);
lp_v=sum(-(s_v+1)*log(v)-d_v*v.^(-1));
lp_f=sum(-0.5*((f-f0)/om_f).^2);
lp_vs=sum(-(s_vs+1)*log(vs)-d_vs*vs.^(-1));

% same left-to-right summation order as a single combined expression
ll=-nlp-(lp_alpha+lp_theta+lp_betaA+lp_betaD+lp_v+lp_f+lp_vs);
clear in_s nlp lp_alpha lp_theta lp_betaA lp_betaD lp_v lp_f lp_vs


%% one-step-ahead forecasts
% Distance kernel: ZZ = sum_k gamma(k)*Z(:,:,k).
ZZ=zeros(N,N);
for kz=1:KZ
    ZZ=ZZ+Z(:,:,kz)*gamma(kz);
end
Q=diag(v.*ep_sd)*exp(-ZZ)*diag(v.*ep_sd);

% Only the square roots of the diagonal of kron(eye(2),Qsym) are needed,
% where Qsym mirrors the lower triangle of Q (same diagonal as Q), so
% diag(kron(eye(2),Qsym)) = [diag(Q); diag(Q)]. Reading it off directly
% avoids inverting a dense 2N-by-2N matrix twice, which costs O(N^3) and
% loses accuracy when exp(-ZZ) is close to singular.
Pd=repmat(sqrt([diag(Q);diag(Q)]),1,T);
B=repmat(beta,1,T);

% Covariate index: XX(:,t) = X(:,t,:)*xi (all zeros when K=0).
XX=zeros(N,T);
for t=1:T
    XX(:,t)=permute(X(:,t,:),[1,3,2])*xi;
end

% Integrate the logistic choice probabilities of both alternatives over
% the 2-D sparse grid; unobserved country-years (M==0) are zeroed out.
NS=numel(sgi_weights);
mask=reshape(M,N*T,1);
U0=zeros(N*T,NS);
U1=zeros(N*T,NS);
for ns=1:NS
    u0=alpha+(R*theta(:,1)).*(sgi_nodes(ns,1)*Pd(1:N,:)+B(1:N,:)+sgi_nodes(ns,2)*ep_sd);
    u1=alpha+(R*theta(:,2)).*(sgi_nodes(ns,1)*Pd(N+(1:N),:)+B(N+(1:N),:)+sgi_nodes(ns,2)*ep_sd)-f-XX;
    % 1./(1+exp(-u)) equals exp(u)./(1+exp(u)) but cannot produce
    % Inf/Inf = NaN for large u (a NaN would survive the 0-mask below).
    U0(:,ns)=reshape(1./(1+exp(-u0)),N*T,1);
    U1(:,ns)=reshape(1./(1+exp(-u1)),N*T,1);
end
U0=mask.*(U0*sgi_weights);
U1=mask.*(U1*sgi_weights);

% Predicted choice: 1 when the integrated U1 is at least U0. Where M==0
% both are 0, so D_hat is 1 there; later comparisons mask those cells.
D_hat=1*reshape(U1>=U0,N,T);
clear u0 u1 mask


%% observed and predicted transitions into or out of democracy
% trns(n,t) = D(n,t)-D(n,t-1) and trns_h(n,t) = D_hat(n,t)-D_hat(n,t-1),
% recorded only where country n is observed in both t-1 and t; all other
% entries (including column 1) stay 0.
trns=zeros(N,T);
trns_h=trns;
both_obs=[false(N,1), M(:,1:end-1)==1 & M(:,2:end)==1];
dD=[zeros(N,1), diff(D,1,2)];
dD_hat=[zeros(N,1), diff(D_hat,1,2)];
trns(both_obs)=dD(both_obs);
trns_h(both_obs)=dD_hat(both_obs);
clear both_obs dD dD_hat

% correct predictions at exact and 5-year windows
% For each observed in-sample transition (periods 2..50), count whether
% the model predicts the same-signed transition in that period
% (correct0) or anywhere in a +/-2-period window (correct5). Both bounds
% are clamped to the panel so a short panel cannot index past column T.
T_in=min(50,T);
correct0=0; correct5=0;
for n=1:N
    for t=2:T_in
        if trns(n,t)~=0
            correct0=correct0+(trns_h(n,t)==trns(n,t));
            win=max(2,t-2):min(T,t+2);
            correct5=correct5+any(trns_h(n,win)==trns(n,t));
        end
    end
end
clear T_in win


%% print results
% In-sample (first 50 periods) summary statistics; the "%" lines print
% shares (fractions), not percentages.
n_obs=sum(sum(M(:,1:50)));
n_trns=sum(sum(abs(trns(:,1:50))));
share_choices=sum(sum(M(:,1:50).*(D(:,1:50)==D_hat(:,1:50))))/n_obs;

report={ ...
    'for Table A10: no-learning model with heterogeneous elite turnover', ...
    ' ', ...
    ['observations = ',num2str(n_obs)], ...
    ['log-likelihood = ',num2str(ll)], ...
    ['% correctly predicted choices = ',num2str(share_choices)], ...
    ['% correctly predicted transitions within 0 years = ',num2str(correct0/n_trns)], ...
    ['% correctly predicted transitions within 5 years = ',num2str(correct5/n_trns)], ...
    ' ', ...
    ' '};
for k=1:numel(report)
    disp(report{k})
end
clear n_obs n_trns share_choices report k



